{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8ea1e6eb",
   "metadata": {},
   "source": [
    "# Launching Multi-Node Training from a Jupyter Environment\n",
 Using the `notebook_launcher`">
    "> Using `notebook_launcher` to run Accelerate training from inside a Jupyter Notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "915d7904",
   "metadata": {},
   "source": [
    "## General Overview\n",
    "\n",
    "This notebook covers how to run the `cv_example.py` script as a Jupyter Notebook and train it on a distributed system. It will also cover the few specific requirements needed for ensuring your environment is configured properly, your data has been prepared properly, and finally how to launch training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00dafb31",
   "metadata": {},
   "source": [
    "## Configuring the Environment\n",
    "\n",
    "Before any training can be performed, an accelerate config file must exist in the system. Usually this can be done by running the following in a terminal:\n",
    "\n",
    "```bash\n",
    "accelerate config\n",
    "```\n",
    "\n",
    "However, if general defaults are fine and you are *not* running on a TPU, accelerate has a utility to quickly write your GPU configuration into a config file via `write_basic_config`.\n",
    "\n",
    "The following cell will restart Jupyter after writing the configuration, as CUDA code was called to perform this. CUDA can't be initialized more than once (once for the single-GPU's notebooks use by default, and then what would be again when `notebook_launcher` is called). It's fine to debug in the notebook and have calls to CUDA, but remember that in order to finally train a full cleanup and restart will need to be performed, such as what is shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ae835e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "#import os\n",
    "#from accelerate.utils import write_basic_config\n",
    "#write_basic_config() # Write a config file\n",
    "#os._exit(00) # Restart the notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e067d11a",
   "metadata": {},
   "source": [
    "## Preparing the Dataset and Model\n",
    "\n",
    "Next you should prepare your dataset. As mentioned at earlier, great care should be taken when preparing the `DataLoaders` and model to make sure that **nothing** is put on *any* GPU. \n",
    "\n",
    "If you do, it is recommended to put that specific code into a function and call that from within the notebook launcher interface, which will be shown later. \n",
    "\n",
    "Make sure the dataset is downloaded based on the directions [here](https://github.com/huggingface/accelerate/tree/main/examples#simple-vision-example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4805cb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, re, torch, PIL\n",
    "import numpy as np\n",
    "\n",
    "from torch.optim.lr_scheduler import OneCycleLR\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor\n",
    "\n",
    "from accelerate import Accelerator\n",
    "from accelerate.utils import set_seed\n",
    "from timm import create_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9938f8e4",
   "metadata": {},
   "source": [
    "First we'll create a function to extract the class name based on a file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bcd4907f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beagle_32.jpg\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "data_dir = \"../../images\"\n",
    "fnames = os.listdir(data_dir)\n",
    "fname = fnames[0]\n",
    "print(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39caa398",
   "metadata": {},
   "source": [
    "In the case here, the label is `beagle`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "43e28d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "def extract_label(fname):\n",
    "    stem = fname.split(os.path.sep)[-1]\n",
    "    return re.search(r\"^(.*)_\\d+\\.jpg$\", stem).groups()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fab40fe9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'beagle'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "extract_label(fname)"
   ]
  },
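  {
   "cell_type": "markdown",
   "id": "5f1d2a7c",
   "metadata": {},
   "source": [
    "As a quick check of the pattern, the sketch below re-declares the same logic under the hypothetical name `extract_label_sketch` and runs it on a few made-up filenames (they are not taken from the dataset listing). It assumes every file is named `<label>_<number>.jpg`; because `.*` is greedy, labels that themselves contain underscores are kept whole:\n",
    "\n",
    "```python\n",
    "import os\n",
    "import re\n",
    "\n",
    "\n",
    "def extract_label_sketch(fname):\n",
    "    # Keep only the final path component, e.g. \"beagle_32.jpg\"\n",
    "    stem = fname.split(os.path.sep)[-1]\n",
    "    # The greedy `.*` captures everything before the last \"_<digits>.jpg\"\n",
    "    return re.search(r\"^(.*)_\\d+\\.jpg$\", stem).groups()[0]\n",
    "\n",
    "\n",
    "# Hypothetical filenames, used only for illustration\n",
    "examples = [\n",
    "    \"beagle_32.jpg\",\n",
    "    \"american_bulldog_7.jpg\",\n",
    "    os.path.join(\"images\", \"pug_101.jpg\"),\n",
    "]\n",
    "labels = [extract_label_sketch(name) for name in examples]\n",
    "print(labels)\n",
    "```\n",
    "\n",
    "A filename that does not match the pattern makes `re.search` return `None`, so the call raises an `AttributeError`; filter such files out beforehand if your folder may contain them."
   ]
  },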
  {
   "cell_type": "markdown",
   "id": "0bf6a733",
   "metadata": {},
   "source": [
    "Next we'll create a `Dataset` class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "72242e1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PetsDataset(Dataset):\n",
    "    def __init__(self, file_names, image_transform=None, label_to_id=None):\n",
    "        self.file_names = file_names\n",
    "        self.image_transform = image_transform\n",
    "        self.label_to_id = label_to_id\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file_names)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        fname = self.file_names[idx]\n",
    "        raw_image = PIL.Image.open(fname)\n",
    "        image = raw_image.convert(\"RGB\")\n",
    "        if self.image_transform is not None:\n",
    "            image = self.image_transform(image)\n",
    "        label = extract_label(fname)\n",
    "        if self.label_to_id is not None:\n",
    "            label = self.label_to_id[label]\n",
    "        return {\"image\": image, \"label\": label}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "222fc93a",
   "metadata": {},
   "source": [
    "And build our dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8a0f2499",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab all the image filenames\n",
    "fnames = [\n",
    "    os.path.join(data_dir, fname)\n",
    "    for fname in fnames\n",
    "    if fname.endswith(\".jpg\")\n",
    "]\n",
    "\n",
    "# Build the labels\n",
    "all_labels = [\n",
    "    extract_label(fname)\n",
    "    for fname in fnames\n",
    "]\n",
    "id_to_label = list(set(all_labels))\n",
    "id_to_label.sort()\n",
    "label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}"
   ]
  },
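  {
   "cell_type": "markdown",
   "id": "6b0e4c19",
   "metadata": {},
   "source": [
    "To make the mapping step concrete, here is a minimal sketch on a handful of made-up labels (the `*_sketch` names and the label values are placeholders, not the real dataset). Sorting the unique labels matters because the iteration order of a set of strings can vary between Python runs, and sorting keeps the label ids stable:\n",
    "\n",
    "```python\n",
    "# Hypothetical labels, standing in for the output of extract_label\n",
    "all_labels_sketch = [\"pug\", \"beagle\", \"pug\", \"abyssinian\", \"beagle\"]\n",
    "\n",
    "# Deduplicate, then sort so the ids do not depend on set ordering\n",
    "id_to_label_sketch = sorted(set(all_labels_sketch))\n",
    "label_to_id_sketch = {lbl: i for i, lbl in enumerate(id_to_label_sketch)}\n",
    "\n",
    "print(id_to_label_sketch)\n",
    "print(label_to_id_sketch)\n",
    "```\n",
    "\n",
    "The list maps an integer id back to its label, and the dictionary goes the other way, so `id_to_label_sketch[label_to_id_sketch[lbl]] == lbl` for every label."
   ]
  },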
  {
   "cell_type": "markdown",
   "id": "343e66fd",
   "metadata": {},
   "source": [
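 Note: This will be stored">
    "> Note: The dataloaders are built inside a function because the seed is set inside the training function, so the random train/eval split must be created after that call."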
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "25430a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloaders(batch_size:int=64):\n",
    "    \"Builds a set of dataloaders with a batch_size\"\n",
    "    random_perm = np.random.permutation(len(fnames))\n",
    "    cut = int(0.8 * len(fnames))\n",
    "    train_split = random_perm[:cut]\n",
    "    eval_split = random_perm[:cut]\n",
    "    \n",
    "    # For training we use a simple RandomResizedCrop\n",
    "    train_tfm = Compose([\n",
    "        RandomResizedCrop((224, 224), scale=(0.5, 1.0)),\n",
    "        ToTensor()\n",
    "    ])\n",
    "    train_dataset = PetsDataset(\n",
    "        [fnames[i] for i in train_split],\n",
    "        image_transform=train_tfm,\n",
    "        label_to_id=label_to_id\n",
    "    )\n",
    "    \n",
    "    # For evaluation we use a deterministic Resize\n",
    "    eval_tfm = Compose([\n",
    "        Resize((224, 224)),\n",
    "        ToTensor()\n",
    "    ])\n",
    "    eval_dataset = PetsDataset(\n",
    "        [fnames[i] for i in eval_split],\n",
    "        image_transform=eval_tfm,\n",
    "        label_to_id=label_to_id\n",
    "    )\n",
    "    \n",
    "    # Instantiate dataloaders\n",
    "    train_dataloader = DataLoader(\n",
    "        train_dataset, \n",
    "        shuffle=True, \n",
    "        batch_size=batch_size,\n",
    "        num_workers=4\n",
    "    )\n",
    "    eval_dataloader = DataLoader(\n",
    "        eval_dataset,\n",
    "        shuffle=False,\n",
    "        batch_size=batch_size*2,\n",
    "        num_workers=4\n",
    "    )\n",
    "    return train_dataloader, eval_dataloader"
   ]
  },
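  {
   "cell_type": "markdown",
   "id": "7d3a9f42",
   "metadata": {},
   "source": [
    "A held-out evaluation set only measures generalization if it shares no samples with the training set. The sketch below shows an 80/20 split over a shuffled index list and checks that the two parts are disjoint. It uses the standard-library `random` module in place of `np.random.permutation` so that it has no dependencies; `n_items` and the seed are placeholder values:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "n_items = 10  # hypothetical dataset size\n",
    "rng = random.Random(42)  # seeded so the shuffle is repeatable\n",
    "perm = list(range(n_items))\n",
    "rng.shuffle(perm)\n",
    "\n",
    "cut = int(0.8 * n_items)\n",
    "train_idx = perm[:cut]  # first 80% of the shuffled indices\n",
    "eval_idx = perm[cut:]  # remaining 20%\n",
    "\n",
    "overlap = set(train_idx) & set(eval_idx)\n",
    "print(len(train_idx), len(eval_idx), overlap)\n",
    "```\n",
    "\n",
    "Slicing the same permutation at one cut point guarantees that every index lands in exactly one of the two splits."
   ]
  },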
  {
   "cell_type": "markdown",
   "id": "91084504",
   "metadata": {},
   "source": [
    "## Writing the Training Function\n",
    "\n",
    "Now we can build our training loop. `notebook_launcher` works by passing in a function to call that will be ran across the distributed system.\n",
    "\n",
    "Here is a basic training loop for our animal classification problem:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ebab7267",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim.lr_scheduler import CosineAnnealingLR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4366f90e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_loop(mixed_precision=\"fp16\", seed:int=42, batch_size:int=64):\n",
    "    set_seed(seed)\n",
    "    # Initialize accelerator\n",
    "    accelerator = Accelerator(mixed_precision=mixed_precision)\n",
    "    # Build dataloaders\n",
    "    train_dataloader, eval_dataloader = get_dataloaders(batch_size)\n",
    "    \n",
    "    # instantiate the model (we build the model here so that the seed also controls new weight initaliziations)\n",
    "    model = create_model(\"resnet50d\", pretrained=True, num_classes=len(label_to_id))\n",
    "    \n",
    "    # Freeze the base model\n",
    "    for param in model.parameters(): \n",
    "        param.requires_grad=False\n",
    "    for param in model.get_classifier().parameters():\n",
    "        param.requires_grad=True\n",
    "        \n",
    "    # We normalize the batches of images to be a bit faster\n",
    "    mean = torch.tensor(model.default_cfg[\"mean\"])[None, :, None, None]\n",
    "    std = torch.tensor(model.default_cfg[\"std\"])[None, :, None, None]\n",
    "    \n",
    "    # To make this constant available on the active device, we set it to the accelerator device\n",
    "    mean = mean.to(accelerator.device)\n",
    "    std = std.to(accelerator.device)\n",
    "    \n",
    "    # Intantiate the optimizer\n",
    "    optimizer = torch.optim.Adam(params=model.parameters(), lr = 3e-2/25)\n",
    "    \n",
    "    # Instantiate the learning rate scheduler\n",
    "    lr_scheduler = OneCycleLR(\n",
    "        optimizer=optimizer, \n",
    "        max_lr=3e-2, \n",
    "        epochs=5, \n",
    "        steps_per_epoch=len(train_dataloader)\n",
    "    )\n",
    "    \n",
    "    # Prepare everything\n",
    "    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n",
    "    # prepare method.\n",
    "    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n",
    "        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n",
    "    )\n",
    "    \n",
    "    # Now we train the model\n",
    "    for epoch in range(5):\n",
    "        model.train()\n",
    "        for step, batch in enumerate(train_dataloader):\n",
    "            # We could avoid this line since we set the accelerator with `device_placement=True`.\n",
    "            batch = {k: v.to(accelerator.device) for k, v in batch.items()}\n",
    "            inputs = (batch[\"image\"] - mean) / std\n",
    "            outputs = model(inputs)\n",
    "            loss = torch.nn.functional.cross_entropy(outputs, batch[\"label\"])\n",
    "            accelerator.backward(loss)\n",
    "            optimizer.step()\n",
    "            lr_scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        model.eval()\n",
    "        accurate = 0\n",
    "        num_elems = 0\n",
    "        for _, batch in enumerate(eval_dataloader):\n",
    "            # We could avoid this line since we set the accelerator with `device_placement=True`.\n",
    "            batch = {k: v.to(accelerator.device) for k, v in batch.items()}\n",
    "            inputs = (batch[\"image\"] - mean) / std\n",
    "            with torch.no_grad():\n",
    "                outputs = model(inputs)\n",
    "            predictions = outputs.argmax(dim=-1)\n",
    "            accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch[\"label\"])\n",
    "            num_elems += accurate_preds.shape[0]\n",
    "            accurate += accurate_preds.long().sum()\n",
    "\n",
    "        eval_metric = accurate.item() / num_elems\n",
    "        # Use accelerator.print to print only on the main process.\n",
    "        accelerator.print(f\"epoch {epoch}: {100 * eval_metric:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "117a7f5f",
   "metadata": {},
   "source": [
    "All that's left is to use the `notebook_launcher`.\n",
    "\n",
    "We pass in the function, the arguments (as a tuple), and the number of processes to train on. (See the [documentation](https://huggingface.co/docs/accelerate/package_reference/launchers#accelerate.notebook_launcher) for more information)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "88a096cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import notebook_launcher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "439fb3e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n",
      "epoch 0: 88.12\n",
      "epoch 1: 91.73\n",
      "epoch 2: 92.58\n",
      "epoch 3: 93.90\n",
      "epoch 4: 94.71\n"
     ]
    }
   ],
   "source": [
    "args = (\"fp16\", 42, 64)\n",
    "notebook_launcher(training_loop, args, num_processes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62fab197",
   "metadata": {},
   "source": [
    "And that's it!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b83c964a",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "This notebook showed how to perform distributed training from inside of a Jupyter Notebook. Some key notes to remember:\n",
    "\n",
    "- Make sure to save any code that use CUDA (or CUDA imports) for the function passed to `notebook_launcher`\n",
    "- Set the `num_processes` to be the number of devices used for training (such as number of GPUs, CPUs, TPUs, etc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
